{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.0109372 ,  0.0257944 ,  0.00579614, -0.01564281], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "# Define the environment\n",
    "class MyWrapper(gym.Wrapper):\n",
    "    \"\"\"CartPole-v1 wrapper with a 4-tuple `step` API and a per-episode step cap.\n",
    "\n",
    "    Args:\n",
    "        max_steps: Number of steps after which an episode is forced to end.\n",
    "            Defaults to 200.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, max_steps=200):\n",
    "        env = gym.make('CartPole-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.max_steps = max_steps\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        \"\"\"Reset the wrapped env and the step counter; return only the state.\"\"\"\n",
    "        state, _ = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Advance one step and return `(state, reward, done, info)`.\"\"\"\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "\n",
    "        # An episode lasts at most max_steps steps. The inner environment's\n",
    "        # own truncation flag is honoured as well instead of being discarded.\n",
    "        self.step_n += 1\n",
    "        done = terminated or truncated or self.step_n >= self.max_steps\n",
    "\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAASAAAADMCAYAAADTcn7NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAATw0lEQVR4nO3db0xb570H8K8N2AmBYwIUuyx44d5VS1H+bCMJnPbFpsUNy9C0rLzopqhjVZSqmYmaMkUaum2qRJWIshftsqXkxe6SvmmZmG66FaWtuKQlmuJCQ8e9hBK0XbWC22A7CfUx0GCD/bsvejmrG5JhAn7s+PuRThQ/z2P7dw7xN+c8j40tIiIgIlLAqroAIspeDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlJGWQCdPHkS69evx6pVq1BTU4O+vj5VpRCRIkoC6A9/+AOam5vx/PPP44MPPsCWLVtQV1eHYDCoohwiUsSi4sOoNTU12LZtG377298CAOLxOCoqKnDgwAH88pe/THU5RKRIbqqfMBqNor+/Hy0tLWab1WqFx+OBz+db8D6RSASRSMS8HY/HMTExgZKSElgslhWvmYiSIyKYnJxEeXk5rNbbX2ilPICuX7+OWCwGp9OZ0O50OnHlypUF79Pa2oojR46kojwiWkZjY2NYt27dbftTHkBL0dLSgubmZvO2YRhwu90YGxuDpmkKKyOihYTDYVRUVKCwsPCO41IeQKWlpcjJyUEgEEhoDwQCcLlcC97HbrfDbrff0q5pGgOIKI39symSlK+C2Ww2VFdXo7u722yLx+Po7u6GruupLoeIFFJyCdbc3IzGxkZs3boV27dvx0svvYTp6Wk88cQTKsohIkWUBNBjjz2Ga9eu4fDhw/D7/fjGN76Bt95665aJaSK6tyl5H9DdCofDcDgcMAyDc0BEaWixr1F+FoyIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKJB1AFy5cwA9+8AOUl5fDYrHg9ddfT+gXERw+fBj3338/Vq9eDY/Hg7/97W8JYyYmJrBnzx5omoaioiLs3bsXU1NTd7UjRJR5kg6g6elpbNmyBSdPnlyw//jx4zhx4gROnTqF3t5erFmzBnV1dZiZmTHH7NmzB0NDQ+jq6kJnZycuXLiAJ598cul7QUSZSe4CADl79qx5Ox6Pi8vlkl/96ldmWygUErvdLq+99pqIiHz44YcCQN5//31zzJtvvikWi0U++eSTRT2vYRgCQAzDuJvyiWiFLPY1uqxzQB999BH8fj88Ho/Z5nA4UFNTA5/PBwDw+XwoKirC1q1bzTEejwdWqxW9vb0LPm4kEkE4HE7YiCjzLWsA+f1+AIDT6UxodzqdZp/f70dZWVlCf25uLoqLi80xX9ba2gqHw2FuFRUVy1k2ESmSEatgLS0tMAzD3MbGxlSXRETLYFkDyOVyAQACgUBCeyAQMPtcLheCwWBC/9zcHCYmJswxX2a326FpWsJGRJlvWQOosrISLpcL3d3dZls4HEZvby90XQcA6LqOUCiE/v5+c8z58+cRj8dRU1OznOUQUZrLTfYOU1NT+Pvf/27e/uijjzAwMIDi4mK43W4cPHgQL7zwAh544AFUVlbiueeeQ3l5OXbv3g0AePDBB/G9730P+/btw6lTpzA7O4umpib8+Mc/Rnl5+bLtGBFlgGSX19555x0BcMvW2NgoIp8vxT/33HPidDrFbrfLjh07ZGRkJOExbty4IT/5yU+koKBANE2TJ554QiYnJ5d9iY+I1Fjsa9QiIqIw/5YkHA7D4XDAMAzOBxGlocW+RjNiFYyI7k0M
ICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEiZpL+Wh2i5iMRhjA0hFvnMbMu/76tYXbTwF1TSvYcBRMpIPI7/7f0P3Jz4xGxzP/QYAyiL8BKMFJr/WjnKVgwgUof5k/UYQKSQQJhAWY0BRMqICM+AshwDiBRjAmUzBhCpIwyfbJdUALW2tmLbtm0oLCxEWVkZdu/ejZGRkYQxMzMz8Hq9KCkpQUFBARoaGhAIBBLGjI6Oor6+Hvn5+SgrK8OhQ4cwNzd393tDGUYYQlkuqQDq6emB1+vFe++9h66uLszOzmLnzp2Ynp42xzzzzDN444030NHRgZ6eHly9ehWPPvqo2R+LxVBfX49oNIqLFy/ilVdewZkzZ3D48OHl2yvKCCKcgs56cheCwaAAkJ6eHhERCYVCkpeXJx0dHeaY4eFhASA+n09ERM6dOydWq1X8fr85pq2tTTRNk0gksqjnNQxDAIhhGHdTPik2OzMt//Xav0nfqX3m5v/v/1RdFi2Dxb5G72oOyDAMAEBxcTEAoL+/H7Ozs/B4POaYDRs2wO12w+fzAQB8Ph82bdoEp9Npjqmrq0M4HMbQ0NCCzxOJRBAOhxM2uhdwFSzbLTmA4vE4Dh48iIcffhgbN24EAPj9fthsNhQVFSWMdTqd8Pv95pgvhs98/3zfQlpbW+FwOMytoqJiqWVTOhG+EzHbLTmAvF4vLl++jPb29uWsZ0EtLS0wDMPcxsbGVvw5aeUJwEnoLLekD6M2NTWhs7MTFy5cwLp168x2l8uFaDSKUCiUcBYUCATgcrnMMX19fQmPN79KNj/my+x2O+x2+1JKpXTGSeisl9QZkIigqakJZ8+exfnz51FZWZnQX11djby8PHR3d5ttIyMjGB0dha7rAABd1zE4OIhgMGiO6erqgqZpqKqqupt9oYzDS7Bsl9QZkNfrxauvvoo//elPKCwsNOdsHA4HVq9eDYfDgb1796K5uRnFxcXQNA0HDhyAruuora0FAOzcuRNVVVV4/PHHcfz4cfj9fjz77LPwer08y8k2zJ+sl1QAtbW1AQC+853vJLSfPn0aP/vZzwAAL774IqxWKxoaGhCJRFBXV4eXX37ZHJuTk4POzk7s378fuq5jzZo1aGxsxNGjR+9uTyjjCBMo61lEMm8WMBwOw+FwwDAMaJqmuhxaouh0CMOvH0N0asJscz/0GJybdiisipbDYl+j/CwYKZWB///RMmIAkToMn6zHACKF+GHUbMcAImXkC39SdmIAkTo8+8l6DCBSRiR+6yS0xaKmGFKCAUTKzN0MY25myrxtsebCrpUprIhSjQFEykg8Dkj8Hw0WC6w5/K7MbMIAovTCS7CswgCiNMMAyiYMIEorFp4BZRUGEKUXBlBWYQBR2vj87IcBlE0YQJRWeAmWXRhAlF4YQFmFAURpxcJLsKzCAKL0wjOgrMIAojRiYQBlGQYQpRVegmUXBhClF54BZRUGEKUPCxhAWYYBRGmFl2DZhQFEaYST0NmGAUTphQGUVfjbn2jFxONxTE5O3va7v25OTSXcFhFMTk4hgtCC43NyclBQUMCPa9xDGEC0Yj799FPs2LED169fX7D/wYq1eKGx1pz3+WxqCrt27UIgdHPB8Zs3b8af//xn5Obyn+29gj9JWjHxeBzj4+MIBoML9peumsPE7P34+OYW5Fpmcb/VB7/fj6s3phYc73K5VrJcUiCpOaC2tjZs3rwZmqZB0zTouo4333zT7J+ZmYHX60VJSQkKCgrQ0NCAQCCQ8Bijo6Oor69Hfn4+ysrKcOjQIczNzS3P3lBGmY4VYSC8A9dmv4rx6L9iILwDkbhddVmUQkkF0Lp163Ds2DH09/fj0qVL+O53v4sf/vCHGBoaAgA888wzeOONN9DR0YGenh5cvXoVjz76qHn/WCyG+vp6
RKNRXLx4Ea+88grOnDmDw4cPL+9eUUaIyipEZfX/37Lgs3gh5uI5SmuiFJO7tHbtWvnd734noVBI8vLypKOjw+wbHh4WAOLz+URE5Ny5c2K1WsXv95tj2traRNM0iUQii35OwzAEgBiGcbfl0woKBoNSVlYm+PzrT2/ZHvyXr8qJY6/LkRd65egLvfLS0X+XEod22/HV1dUyOzurerdoERb7Gl3yHFAsFkNHRwemp6eh6zr6+/sxOzsLj8djjtmwYQPcbjd8Ph9qa2vh8/mwadMmOJ1Oc0xdXR3279+PoaEhfPOb30yqhitXrqCgoGCpu0Ar7NNPP73j5fWNieu42N2KQLQSOZZZFFv/BzdnFp6ABj6/xB8eHkZODs+S0t3U1MLzeF+WdAANDg5C13XMzMygoKAAZ8+eRVVVFQYGBmCz2VBUVJQw3ul0wu/3AwD8fn9C+Mz3z/fdTiQSQSQSMW+Hw2EAgGEYnD9KY+Fw+LZL8AAQDE2jvasXQO+iHi8Wi8EwDFitfPtaupuenl7UuKQD6Otf/zoGBgZgGAb++Mc/orGxET09PUkXmIzW1lYcOXLklvaamhpomraiz01Ld+3aNeTl5S3b461Zswa1tbVchs8A8ycJ/0zS/5XYbDZ87WtfQ3V1NVpbW7Flyxb8+te/hsvlQjQaRSgUShgfCATM5VOXy3XLqtj87Tstsba0tMAwDHMbGxtLtmwiSkN3fS4bj8cRiURQXV2NvLw8dHd3m30jIyMYHR2FrusAAF3XMTg4mPC+kK6uLmiahqqqqts+h91uN5f+5zciynxJncu2tLRg165dcLvdmJycxKuvvop3330Xb7/9NhwOB/bu3Yvm5mYUFxdD0zQcOHAAuq6jtrYWALBz505UVVXh8ccfx/Hjx+H3+/Hss8/C6/XCbuf7P4iyTVIBFAwG8dOf/hTj4+NwOBzYvHkz3n77bTzyyCMAgBdffBFWqxUNDQ2IRCKoq6vDyy+/bN4/JycHnZ2d2L9/P3Rdx5o1a9DY2IijR48u715RWrBYLCgsLMTMzMyyPB5XPO89FrnTMkWaCofDcDgcMAyDl2NpLBaLwe/3Ix6PL8vj2Ww2lJWV8cOoGWCxr1EuJ9CKycnJwVe+8hXVZVAa4xsqiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTK5qgtYChEBAITDYcWVENFC5l+b86/V28nIALpx4wYAoKKiQnElRHQnk5OTcDgct+3PyAAqLi4GAIyOjt5x5yhROBxGRUUFxsbGoGma6nIyAo/Z0ogIJicnUV5efsdxGRlAVuvnU1cOh4P/KJZA0zQetyTxmCVvMScHnIQmImUYQESkTEYGkN1ux/PPPw+73a66lIzC45Y8HrOVZZF/tk5GRLRCMvIMiIjuDQwgIlKGAUREyjCAiEiZjAygkydPYv369Vi1ahVqamrQ19enuiRlWltbsW3bNhQWFqKsrAy7d+/GyMhIwpiZmRl4vV6UlJSgoKAADQ0NCAQCCWNGR0dRX1+P/Px8lJWV4dChQ5ibm0vlrihz7NgxWCwWHDx40GzjMUsRyTDt7e1is9nk97//vQwNDcm+ffukqKhIAoGA6tKUqKurk9OnT8vly5dlYGBAvv/974vb7ZapqSlzzFNPPSUVFRXS3d0tly5dktraWnnooYfM/rm5Odm4caN4PB7561//KufOnZPS0lJpaWlRsUsp1dfXJ+vXr5fNmzfL008/bbbzmKVGxgXQ9u3bxev1mrdjsZiUl5dLa2urwqrSRzAYFADS09MjIiKhUEjy8vKko6PDHDM8PCwAxOfziYjIuXPnxGq1it/vN8e0tbWJpmkSiURSuwMpNDk5
KQ888IB0dXXJt7/9bTOAeMxSJ6MuwaLRKPr7++HxeMw2q9UKj8cDn8+nsLL0YRgGgH98YLe/vx+zs7MJx2zDhg1wu93mMfP5fNi0aROcTqc5pq6uDuFwGENDQymsPrW8Xi/q6+sTjg3AY5ZKGfVh1OvXryMWiyX80AHA6XTiypUriqpKH/F4HAcPHsTDDz+MjRs3AgD8fj9sNhuKiooSxjqdTvj9fnPMQsd0vu9e1N7ejg8++ADvv//+LX08ZqmTUQFEd+b1enH58mX85S9/UV1KWhsbG8PTTz+Nrq4urFq1SnU5WS2jLsFKS0uRk5Nzy2pEIBCAy+VSVFV6aGpqQmdnJ9555x2sW7fObHe5XIhGowiFQgnjv3jMXC7Xgsd0vu9e09/fj2AwiG9961vIzc1Fbm4uenp6cOLECeTm5sLpdPKYpUhGBZDNZkN1dTW6u7vNtng8ju7ubui6rrAydUQETU1NOHv2LM6fP4/KysqE/urqauTl5SUcs5GREYyOjprHTNd1DA4OIhgMmmO6urqgaRqqqqpSsyMptGPHDgwODmJgYMDctm7dij179ph/5zFLEdWz4Mlqb28Xu90uZ86ckQ8//FCefPJJKSoqSliNyCb79+8Xh8Mh7777royPj5vbZ599Zo556qmnxO12y/nz5+XSpUui67roum72zy8p79y5UwYGBuStt96S++67L6uWlL+4CibCY5YqGRdAIiK/+c1vxO12i81mk+3bt8t7772nuiRlACy4nT592hxz8+ZN+fnPfy5r166V/Px8+dGPfiTj4+MJj/Pxxx/Lrl27ZPXq1VJaWiq/+MUvZHZ2NsV7o86XA4jHLDX46ziISJmMmgMionsLA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlLm/wAjNqMnhPf9bwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "# Draw the current game frame\n",
    "def show():\n",
    "    \"\"\"Display the frame returned by `env.render()` in a 3x3-inch figure.\"\"\"\n",
    "    plt.figure(figsize=(3, 3))\n",
    "    frame = env.render()\n",
    "    plt.imshow(frame)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "POOOk15_K6KA"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env.observation_space= Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)\n",
      "env.action_space= Discrete(2)\n",
      "state= [-0.0293525  -0.04349425  0.02145252 -0.02283522]\n",
      "action= 0\n",
      "next_state= [-0.03022238 -0.23891717  0.02099581  0.27653828]\n",
      "reward= 1.0\n",
      "done= False\n"
     ]
    }
   ],
   "source": [
    "# Get to know the game environment\n",
    "def test_env():\n",
    "    \"\"\"Print the env's spaces and the outcome of one randomly sampled step.\"\"\"\n",
    "    print('env.observation_space=', env.observation_space)\n",
    "    print('env.action_space=', env.action_space)\n",
    "\n",
    "    state = env.reset()\n",
    "    action = env.action_space.sample()\n",
    "    next_state, reward, done, _ = env.step(action)\n",
    "\n",
    "    # Report every value from the transition, one per line\n",
    "    report = (\n",
    "        ('state=', state),\n",
    "        ('action=', action),\n",
    "        ('next_state=', next_state),\n",
    "        ('reward=', reward),\n",
    "        ('done=', done),\n",
    "    )\n",
    "    for label, value in report:\n",
    "        print(label, value)\n",
    "\n",
    "\n",
    "test_env()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "action= tensor(1)\n",
      "log_prob= tensor(-1.6094)\n"
     ]
    }
   ],
   "source": [
    "#测试torch.distributions.Categorical\n",
    "def test_dist():\n",
    "    import torch\n",
    "\n",
    "    #创建分布\n",
    "    dist = torch.distributions.Categorical(torch.FloatTensor([0.1, 0.2, 0.7]))\n",
    "\n",
    "    #从分布中采样\n",
    "    action = dist.sample()\n",
    "    print('action=', action)\n",
    "\n",
    "    #计算概率的log\n",
    "    log_prob = dist.log_prob(action)\n",
    "    print('log_prob=', log_prob)\n",
    "    \n",
    "    \n",
    "\n",
    "\n",
    "test_dist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "Ho_UHf49N9i4"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, tensor([-0.3408], grad_fn=<SqueezeBackward1>))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "#定义模型\n",
    "class Model(torch.nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.sequential = torch.nn.Sequential(torch.nn.Linear(4, 16),\n",
    "                                              torch.nn.ReLU(),\n",
    "                                              torch.nn.Linear(16, 2),\n",
    "                                              torch.nn.Softmax(dim=1))\n",
    "\n",
    "    def forward(self, state):\n",
    "        #[4] -> [1, 4]\n",
    "        state = torch.FloatTensor(state).unsqueeze(0)\n",
    "\n",
    "        #计算当前state下各个动作的概率\n",
    "        #[1, 4] -> [1, 2]\n",
    "        probs = self.sequential(state)\n",
    "\n",
    "        #以概率创建分布\n",
    "        dist = torch.distributions.Categorical(probs)\n",
    "\n",
    "        #在概率中采样,获得action\n",
    "        #scala\n",
    "        action = dist.sample()\n",
    "\n",
    "        #求动作的概率对数\n",
    "        #scala\n",
    "        log_prob = dist.log_prob(action)\n",
    "\n",
    "        return action.item(), log_prob\n",
    "\n",
    "\n",
    "model = Model()\n",
    "model([1, 2, 3, 4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5 693.1473388671875\n",
      "0.9 294.32183837890625\n",
      "0.9999 0.100021593272686\n"
     ]
    }
   ],
   "source": [
    "#测试概率和loss之间的关系\n",
    "def test_loss(p0):\n",
    "    #以概率创建分布\n",
    "    dist = torch.distributions.Categorical(torch.FloatTensor([p0, 1 - p0]))\n",
    "\n",
    "    log_probs = []\n",
    "    for _ in range(1000):\n",
    "        #在概率中采样,获得action\n",
    "        action = dist.sample()\n",
    "\n",
    "        #求动作的概率对数\n",
    "        log_prob = dist.log_prob(action)\n",
    "\n",
    "        log_probs.append(log_prob)\n",
    "\n",
    "    #动作的概率,取值是0-1,取对数后是-inf到0,符号取反之后是0到inf\n",
    "    #这意味着loss越接近0,动作的概率越高,loss越大,动作的概率越低\n",
    "    #动作本身也是从概率采样而来,所以概率越倾斜,loss越小.\n",
    "    #也就是说[0.5,0.5]将导致大loss, [0.0,1.0]将导致小loss\n",
    "    log_probs = torch.FloatTensor(log_probs)\n",
    "    log_probs = -log_probs\n",
    "    loss = log_probs.sum().item()\n",
    "\n",
    "    return loss\n",
    "\n",
    "\n",
    "for i in [0.5, 0.9, 0.9999]:\n",
    "    print(i, test_loss(i))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "20.65"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "import random\n",
    "\n",
    "\n",
    "# Play one episode and return reward_sum\n",
    "def test(play=False):\n",
    "    \"\"\"Play one episode with the global `model` on `env`; return the total reward.\n",
    "\n",
    "    Args:\n",
    "        play: When True, redraw the game frame after roughly 20% of the steps.\n",
    "\n",
    "    Returns:\n",
    "        The sum of the rewards collected during the episode.\n",
    "    \"\"\"\n",
    "    state = env.reset()\n",
    "    reward_sum = 0\n",
    "    done = False\n",
    "\n",
    "    # Evaluation only: the log-prob is discarded, so there is no reason to\n",
    "    # record autograd history for these forward passes.\n",
    "    with torch.no_grad():\n",
    "        while not done:\n",
    "            action, _ = model(state)\n",
    "            state, reward, done, _ = env.step(action)\n",
    "            reward_sum += reward\n",
    "\n",
    "            # Redraw only occasionally to keep playback from slowing the loop\n",
    "            if play and random.random() < 0.2:\n",
    "                display.clear_output(wait=True)\n",
    "                show()\n",
    "\n",
    "    return reward_sum\n",
    "\n",
    "\n",
    "sum([test() for _ in range(20)]) / 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "NCNvyElRStWG"
   },
   "outputs": [],
   "source": [
    "# Training\n",
    "def train(episodes=5000, gamma=1.0, lr=1e-3,\n",
    "          save_path='save/3.Reinforce_CartPole'):\n",
    "    \"\"\"Train the global `model` on `env` with a REINFORCE-style loss, then save it.\n",
    "\n",
    "    Each episode contributes one optimizer step whose loss is the episode's\n",
    "    (discounted) reward sum multiplied by the negated sum of the log-probs of\n",
    "    the actions taken. Every 200 episodes the mean reward of 5 test episodes\n",
    "    is printed.\n",
    "\n",
    "    Args:\n",
    "        episodes: Number of episodes to play.\n",
    "        gamma: Discount factor; the reward at step t is scaled by gamma**t.\n",
    "            1.0 means no discounting.\n",
    "        lr: Learning rate for the Adam optimizer.\n",
    "        save_path: File the trained model is written to with `torch.save`.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "\n",
    "    # Train for N episodes\n",
    "    for i in range(episodes):\n",
    "        # Collect the data of one episode\n",
    "        rewards = []\n",
    "        log_probs = []\n",
    "        state = env.reset()\n",
    "        done = False\n",
    "        while not done:\n",
    "            action, log_prob = model(state)\n",
    "            state, reward, done, _ = env.step(action)\n",
    "\n",
    "            # Record rewards and log_probs\n",
    "            rewards.append(reward)\n",
    "            log_probs.append(log_prob)\n",
    "\n",
    "        # [steps]\n",
    "        rewards = torch.FloatTensor(rewards)\n",
    "        # [steps]\n",
    "        log_probs = torch.cat(log_probs)\n",
    "\n",
    "        # Discount the rewards, then sum them\n",
    "        decay = gamma**torch.arange(len(rewards))\n",
    "        episode_return = (rewards * decay).sum()\n",
    "\n",
    "        # The loss combines both parts: reward sum times negated log-prob sum\n",
    "        loss = episode_return * -log_probs.sum()\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if i % 200 == 0:\n",
    "            print(i, sum([test() for _ in range(5)]) / 5)\n",
    "\n",
    "    # torch.save does not create missing parent directories\n",
    "    save_dir = os.path.dirname(save_path)\n",
    "    if save_dir:\n",
    "        os.makedirs(save_dir, exist_ok=True)\n",
    "    torch.save(model, save_path)\n",
    "\n",
    "\n",
    "#train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "3FamHmxyhBEU"
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAASAAAADMCAYAAADTcn7NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAVbUlEQVR4nO3df2wU550G8Gd217u2Wc8a2/FuXOwDtbTEx4+0BswklVolLg6xotL4j7ZCqYsQudA1grhCV58SoqRpHVGpaZMSOOlaiFQlVG6PVrEIxDKJUcSCwakrY4JF7uDsBnY34HjWNuzPee8P4gkLdvDaxq8Hno80EvO+73q/M3gfz74zO6sIIQSIiCSwyS6AiO5eDCAikoYBRETSMICISBoGEBFJwwAiImkYQEQkDQOIiKRhABGRNAwgIpJGWgDt3LkT8+fPR3Z2NiorK9HR0SGrFCKSREoA/elPf0JDQwOee+45fPDBB1i2bBmqq6sRDodllENEkigyPoxaWVmJFStW4He/+x0AwDAMlJaWYvPmzfjZz3420+UQkSSOmX7CeDyOzs5ONDY2mm02mw1VVVUIBAJjPiYWiyEWi5nrhmFgYGAAhYWFUBTlttdMRJkRQmBoaAglJSWw2cZ/ozXjAXTp0iWkUil4vd60dq/XizNnzoz5mKamJjz//PMzUR4RTaP+/n7Mmzdv3P4ZD6DJaGxsRENDg7mu6zrKysrQ398PVVUlVkZEY4lEIigtLUVeXt4XjpvxACoqKoLdbkcoFEprD4VC8Pl8Yz7G5XLB5XLd1K6qKgOIaBa71RTJjJ8FczqdqKioQFtbm9lmGAba2tqgadpMl0NEEkl5C9bQ0IC6ujosX74cK1euxG9+8xuMjIxg/fr1MsohIkmkBND3v/99fPLJJ9i+fTuCwSDuv/9+HDx48KaJaSK6s0m5DmiqIpEIPB4PdF3nHBDRLDTR1yg/C0ZE0jCAiEgaBhARScMAIiJpGEBEJA0DiIikYQARkTQMICKShgFERNIwgIhIGgYQEUnDACIiaRhARCQNA4iIpGEAEZE0DCAikoYBRETSMICISBoGEBFJwwAiImkYQEQkDQOIiKRhABGRNAwgIpKGAURE0jCAiEgaBhARScMAIiJpMg6gI0eO4LHHHkNJSQkURcFf//rXtH4hBLZv3457770XOTk5qKqqwtmzZ9PGDAwMYN26dVBVFfn5+diwYQOGh4entCFEZD0ZB9DIyAiWLVuGnTt3jtm/Y8cOvPLKK9i9ezeOHz+OOXPmoLq6GtFo1Byzbt069PT0oLW1FS0tLThy5AiefPLJyW8FEVmTmAIAYv/+/ea6YRjC5/OJX/3qV2bb4OCgcLlc4s033xRCCHH69GkBQJw4ccIc8/bbbwtFUcTHH388oefVdV0AELquT6V8IrpNJvoandY5oHPnziEYDKKqqsps83g8qKysRCAQAAAEAgHk5+dj+fLl5piqqirYbDYcP358zJ8bi8UQiUTSFiKyvmkNoGAwCADwer1p7V6v1+wLBoMoLi5O63c4HCgoKDDH3KipqQkej8dcSktLp7NsIpLEEmfBGhsboeu6ufT398suiYimwbQGkM/nAwCEQqG09lAoZPb5fD6Ew+G0/mQyiYGBAXPMjVwuF1RVTVuIyPqmNYAWLFgAn8+HtrY2sy0SieD48ePQNA0AoGkaBgcH0dnZaY45fPgwDMNAZWXldJZDRLOcI9MHDA8P46OPPjLXz507h66uLhQUFKCsrAxbt27Fiy++iIULF2LBggV49tlnUVJSgrVr1wIA7rvvPjzyyCPYuHEjdu/ejUQigfr6evzgBz9ASUnJtG0YEVlApqfX3n33XQHgpqWurk4Ice1U/LPPPiu8Xq9wuVzi4YcfFr29vWk/4/Lly+KHP/yhcLvdQlVVsX79ejE0NDTtp/iISI6JvkYVIYSQmH+TEolE4PF4oOs654OIZqGJvkYtcRaMiO5M
DCAikoYBRETSMICISBoGEBFJwwAiImkYQEQkDQOIiKRhABGRNAwgIpKGAURE0jCAiEgaBhARScMAIiJpGEBEJA0DiIikYQARkTQMICKShgFERNIwgIhImoy/lodINmGkkLgSQTI2gmR0GKlEDJ5598HmcMoujTLEACLLGQ6fw0fv7IJIJWGkkrDZs3Df2n9Hztx7ZZdGGeJbMLIeIZCMjiAVvwqRSsBIxhDVQ7d+HM06DCCyHJd6D7JyPea6MFJIjAzCgl9xd9djAJHlOFy5sGe50tpS8auSqqGpYACR9Sg2KDZ7WtPVTy/i2reEk5UwgMhyFMWG3KJ/SWuLDgaZPxbEACLrURS48grTmoSRgjBSkgqiycoogJqamrBixQrk5eWhuLgYa9euRW9vb9qYaDQKv9+PwsJCuN1u1NbWIhRKP0PR19eHmpoa5Obmori4GNu2bUMymZz61tBdQVEU2OzpV5Ako8NIRoclVUSTlVEAtbe3w+/349ixY2htbUUikcDq1asxMjJijnn66afx1ltvobm5Ge3t7bhw4QIef/xxsz+VSqGmpgbxeBxHjx7F66+/jr1792L79u3Tt1V0x8ueW5I2DxS/oiNxRZdYEU2KmIJwOCwAiPb2diGEEIODgyIrK0s0NzebYz788EMBQAQCASGEEAcOHBA2m00Eg0FzzK5du4SqqiIWi03oeXVdFwCErutTKZ8sbOTyP8XJ//KLjt0bry3/+W8icvEj2WXRZyb6Gp3SHJCuX/uLU1BQAADo7OxEIpFAVVWVOWbRokUoKytDIBAAAAQCASxZsgRer9ccU11djUgkgp6enjGfJxaLIRKJpC10d7M7XACUzxuEQHzokrR6aHImHUCGYWDr1q148MEHsXjxYgBAMBiE0+lEfn5+2liv14tgMGiOuT58RvtH+8bS1NQEj8djLqWlpZMtm+4Qit0BuzP7uhaB4fD/SquHJmfSAeT3+3Hq1Cns27dvOusZU2NjI3RdN5f+/v7b/pw0uzmy3cjO96W1CcPg1dAWM6kPo9bX16OlpQVHjhzBvHnzzHafz4d4PI7BwcG0o6BQKASfz2eO6ejoSPt5o2fJRsfcyOVyweVyjdlHdyfFZr/pauhkdBgQBqDYx3kUzTYZHQEJIVBfX4/9+/fj8OHDWLBgQVp/RUUFsrKy0NbWZrb19vair68PmqYBADRNQ3d3N8LhsDmmtbUVqqqivLx8KttCdxm7Mydt/erABRjJuKRqaDIyOgLy+/1444038Le//Q15eXnmnI3H40FOTg48Hg82bNiAhoYGFBQUQFVVbN68GZqmYdWqVQCA1atXo7y8HE888QR27NiBYDCIZ555Bn6/n0c5lBG398u4fPa4uZ6KX+HFiBaTUQDt2rULAPDtb387rX3Pnj348Y9/DAB4+eWXYbPZUFtbi1gshurqarz22mvmWLvdjpaWFmzatAmapmHOnDmoq6vDCy+8MLUtobtOVq6ati6MFFLxKBzZbkkVUaYUYcFZu0gkAo/HA13XoarqrR9Ad6TBvlM4e/BV4LNfYcXmwMJH/PCU/qvkymiir1F+FowsK9tTDIdrjrkujCQS0SGeCbMQBhBZliPbfdN9oI14VFI1NBkMILIsxWYDFCWt7cplXiNmJQwgsizFZke2J/2q+qgeHmc0zUYMILIsxeYY42ro1LWLEckSGEBkaTfOASWjI0glOA9kFQwgsizFvDPi5/NA8eHLSFwdklcUZYQBRJaWU/Cla5PRnxGGAZHi3TWtggFEluZw5iD9vkAG74xoIQwgsjRblivtvkBCGBgJn5dXEGWEAUSW5sjOg9NdkNaWSsZ4NbRFMIDI0mwOJ+xZ2Wlt1+4LxACyAgYQWd6N35I68sl5CIMT0VbAACLLc/u+nLaeikcheDGiJTCAyPJunAMSqSRSiZikaigTDCCyNEVRoNhu+JbU2DBig6FxHkGzCQOILC8n3wfbdRPR146ArkqsiCaKAUSWlzUnHzZ7
VlpbivcFsgQGEFmeze6AcsN9gXhbDmtgAJHlKfYsuDzFaW1XLv9TUjWUCQYQWZ7NngVXXlFamzASEAZPxc92DCCyPkWB7YZvSU1cGYKR5Kn42Y4BRJanKAocrty0tlgkhGTsiqSKaKIYQHRHcHu/jOtvyyEMwW9JtYCMvhmVSKYrV64gHh/7u9+jSVzLn88+gyqEgcjlIKLCOeZ4AHC73XA4+BKQiXufLOOXv/wl9u7dO2bffG8eXlqvwTkaKKkEXvyPLfjv98+OOV5RFPzlL3/BypUrb1O1NBEMILIMXdfx8ccfj9k39KkL/ZeWI5H3TURShfiS6yz0wRPjjlcUZdyjKZo5Gc0B7dq1C0uXLoWqqlBVFZqm4e233zb7o9Eo/H4/CgsL4Xa7UVtbi1Ao/TM5fX19qKmpQW5uLoqLi7Ft2zYkk7x1Ak3NSDSBzoEVOB9dgoHEl9Az/E3EXfdBufVDSaKMAmjevHl46aWX0NnZiZMnT+Khhx7Cd7/7XfT09AAAnn76abz11ltobm5Ge3s7Lly4gMcff9x8fCqVQk1NDeLxOI4ePYrXX38de/fuxfbt26d3q+iuYwhgKJGP0YloAw4snL8QdjvPs8xmGf3vPPbYY3j00UexcOFCfPWrX8UvfvELuN1uHDt2DLqu4/e//z1+/etf46GHHkJFRQX27NmDo0eP4tixYwCAd955B6dPn8Yf//hH3H///VizZg1+/vOfY+fOnTwcpqkRAiOXjsGOBAABJz7FPc6x337R7DHpOaBUKoXm5maMjIxA0zR0dnYikUigqqrKHLNo0SKUlZUhEAhg1apVCAQCWLJkCbzez79Ot7q6Gps2bUJPTw++/vWvZ1TDmTNn4Ha7J7sJZDEDAwPj9gkIDPS3YSg/hn/8nwH9k3/gnxf+B8nU+FdDnz9/HgUFBeP20+QNDw9PaFzGAdTd3Q1N0xCNRuF2u7F//36Ul5ejq6sLTqcT+fn5aeO9Xi+CwSAAIBgMpoXPaP9o33hisRhisc+vao1EIgCuTUpy/ujucf3vwFgOdZzFOyfOImVM7H7Qw8PDGBwcnIbK6EYjIyMTGpdxAH3ta19DV1cXdF3Hn//8Z9TV1aG9vT3jAjPR1NSE559//qb2yspKqKp6W5+bZo8333zzC/sNIczrgCZi8eLFeOCBB6ZYFY1l9CDhVjKeoXM6nfjKV76CiooKNDU1YdmyZfjtb38Ln8+HeDx+01+UUCgEn88HAPD5fDedFRtdHx0zlsbGRui6bi79/f2Zlk1Es9CUTxEYhoFYLIaKigpkZWWhra3N7Ovt7UVfXx80TQMAaJqG7u5uhMOf36ultbUVqqqivLx83OdwuVzmqf/RhYisL6O3YI2NjVizZg3KysowNDSEN954A++99x4OHToEj8eDDRs2oKGhAQUFBVBVFZs3b4amaVi1ahUAYPXq1SgvL8cTTzyBHTt2IBgM4plnnoHf74fL5brFsxPRnSajAAqHw/jRj36EixcvwuPxYOnSpTh06BC+853vAABefvll2Gw21NbWIhaLobq6Gq+99pr5eLvdjpaWFmzatAmapmHOnDmoq6vDCy+8ML1bRXek7OzsaTv6VRQFdrv91gPptlKEBb/DNhKJwOPxQNd1vh27i3z66acTPr07EcXFxTzyvk0m+hrlZ8HIMubOnYu5c+fKLoOmEa9TJyJpGEBEJA0DiIikYQARkTQMICKShgFERNIwgIhIGgYQEUnDACIiaRhARCQNA4iIpGEAEZE0DCAikoYBRETSMICISBoGEBFJwwAiImkYQEQkDQOIiKRhABGRNAwgIpKGAURE0jCAiEgaBhARScMAIiJpGEBEJA0DiIikYQARkTQMICKShgFERNI4ZBcwGUIIAEAkEpFcCRGNZfS1OfpaHY8lA+jy5csAgNLSUsmVENEXGRoagsfjGbffkgFUUFAAAOjr6/vCjaN0kUgEpaWl6O/vh6qqssuxBO6zyRFCYGhoCCUlJV84
zpIBZLNdm7ryeDz8pZgEVVW53zLEfZa5iRwccBKaiKRhABGRNJYMIJfLheeeew4ul0t2KZbC/ZY57rPbSxG3Ok9GRHSbWPIIiIjuDAwgIpKGAURE0jCAiEgaSwbQzp07MX/+fGRnZ6OyshIdHR2yS5KmqakJK1asQF5eHoqLi7F27Vr09vamjYlGo/D7/SgsLITb7UZtbS1CoVDamL6+PtTU1CA3NxfFxcXYtm0bksnkTG6KNC+99BIURcHWrVvNNu6zGSIsZt++fcLpdIo//OEPoqenR2zcuFHk5+eLUCgkuzQpqqurxZ49e8SpU6dEV1eXePTRR0VZWZkYHh42xzz11FOitLRUtLW1iZMnT4pVq1aJBx54wOxPJpNi8eLFoqqqSvz9738XBw4cEEVFRaKxsVHGJs2ojo4OMX/+fLF06VKxZcsWs537bGZYLoBWrlwp/H6/uZ5KpURJSYloamqSWNXsEQ6HBQDR3t4uhBBicHBQZGVliebmZnPMhx9+KACIQCAghBDiwIEDwmaziWAwaI7ZtWuXUFVVxGKxmd2AGTQ0NCQWLlwoWltbxbe+9S0zgLjPZo6l3oLF43F0dnaiqqrKbLPZbKiqqkIgEJBY2eyh6zqAzz+w29nZiUQikbbPFi1ahLKyMnOfBQIBLFmyBF6v1xxTXV2NSCSCnp6eGax+Zvn9ftTU1KTtG4D7bCZZ6sOoly5dQiqVSvtPBwCv14szZ85Iqmr2MAwDW7duxYMPPojFixcDAILBIJxOJ/Lz89PGer1eBINBc8xY+3S07060b98+fPDBBzhx4sRNfdxnM8dSAURfzO/349SpU3j//fdllzKr9ff3Y8uWLWhtbUV2drbscu5qlnoLVlRUBLvdftPZiFAoBJ/PJ6mq2aG+vh4tLS149913MW/ePLPd5/MhHo9jcHAwbfz1+8zn8425T0f77jSdnZ0Ih8P4xje+AYfDAYfDgfb2drzyyitwOBzwer3cZzPEUgHkdDpRUVGBtrY2s80wDLS1tUHTNImVySOEQH19Pfbv34/Dhw9jwYIFaf0VFRXIyspK22e9vb3o6+sz95mmaeju7kY4HDbHtLa2QlVVlJeXz8yGzKCHH34Y3d3d6OrqMpfly5dj3bp15r+5z2aI7FnwTO3bt0+4XC6xd+9ecfr0afHkk0+K/Pz8tLMRd5NNmzYJj8cj3nvvPXHx4kVzuXLlijnmqaeeEmVlZeLw4cPi5MmTQtM0oWma2T96Snn16tWiq6tLHDx4UNxzzz131Snl68+CCcF9NlMsF0BCCPHqq6+KsrIy4XQ6xcqVK8WxY8dklyQNgDGXPXv2mGOuXr0qfvKTn4i5c+eK3Nxc8b3vfU9cvHgx7eecP39erFmzRuTk5IiioiLx05/+VCQSiRneGnluDCDus5nB23EQkTSWmgMiojsLA4iIpGEAEZE0DCAikoYBRETSMICISBoGEBFJwwAiImkYQEQkDQOIiKRhABGRNAwgIpLm/wGaYjeUM78xkAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "200.0"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = torch.load('save/3.Reinforce_CartPole')\n",
    "\n",
    "#试玩\n",
    "test(play=True)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyPoxzAc/iqqWPY8sm+LcHBH",
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "unit5",
   "private_outputs": true,
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
